{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VNNGP: Variational Nearest Neighbor Gaussian Procceses\n",
    "\n",
    "## Overview\n",
    "\n",
    "In this notebook, we will give an overview of how to use variational nearest neighbor Gaussian processes (VNNGP) (https://arxiv.org/abs/2202.01694) to rapidly train on the `elevators` UCI dataset.\n",
    "\n",
    "Similar to SVGP (https://arxiv.org/abs/1309.6835), VNNGP is a variational inducing point-based approach. Unlike SVGP that is typically limited to thousands of inducing points, VNNGP makes an additional approximation: it assumes that every inducing point and data point only depends on $\\leq K$ other inducing points. This is advantageous for multiple reasons:\n",
    "- The variational KL divergence term affords an unbiased stochastic estimate from a **minibatch of inducing points** \n",
    "- Consequentially, an unbiased estimate of the ELBO can be computed in $O(K^3)$ time.\n",
    "\n",
    "With this scalability, we recommend using $M=N$ inducing points, placing inducing points at every observed input."
   ]
  },
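  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the nearest-neighbor assumption concrete, the cell below gives a minimal sketch that is independent of GPyTorch's internals. It assumes one common way of building such a structure: the points are given a fixed ordering, and each point depends on at most $K$ of its nearest neighbors among the points that come before it. The helper `nearest_predecessors` and the toy inputs are hypothetical names introduced only for illustration; the library may order points and search for neighbors differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def nearest_predecessors(points, k):\n",
    "    # Hypothetical helper (not part of GPyTorch): for each point i, return the\n",
    "    # indices of up to k nearest points among those that precede it (j < i).\n",
    "    dists = torch.cdist(points, points)  # (M, M) pairwise Euclidean distances\n",
    "    return [torch.argsort(dists[i, :i])[:k] for i in range(points.size(0))]\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "toy_points = torch.randn(6, 2)  # six toy 2-D inducing points\n",
    "neighbor_sets = nearest_predecessors(toy_points, k=2)\n",
    "\n",
    "# Each point has at most k neighbors, all of which come earlier in the ordering\n",
    "for i, idx in enumerate(neighbor_sets):\n",
    "    print(i, idx.tolist())"
   ]
  },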
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Make plots inline\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "smoke_test = ('CI' in os.environ)\n",
    "\n",
    "if not smoke_test and not os.path.isfile('../elevators.mat'):\n",
    "    print('Downloading \\'elevators\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')\n",
    "\n",
    "\n",
    "if smoke_test:  # this is for running the notebook in our testing framework\n",
    "    X, y = torch.randn(100, 3), torch.randn(100)\n",
    "else:\n",
    "    data = torch.Tensor(loadmat('../elevators.mat')['data'])\n",
    "    X = data[:1000, :-1]\n",
    "    X = X - X.min(0)[0]\n",
    "    X = 2 * (X / X.max(0)[0].clamp_min(1e-6)) - 1\n",
    "    y = data[:1000, -1]\n",
    "    y = y.sub(y.mean()).div(y.std())\n",
    "\n",
    "\n",
    "train_n = int(floor(0.8 * len(X)))\n",
    "train_x = X[:train_n, :].contiguous()\n",
    "train_y = y[:train_n].contiguous()\n",
    "\n",
    "test_x = X[train_n:, :].contiguous()\n",
    "test_y = y[train_n:].contiguous()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a VNNGP model\n",
    "\n",
    "Creating a VNNGP model is similar to creating [other variational models](./SVGP_Regression_CUDA.ipynb). Here are some things that are specific to VNNGP:\n",
    "\n",
    "1. You need to use the `gpytorch.variational.NNVariationalStrategy` variational strategy class.\n",
    "2. You need to use the `gpytorch.variational.MeanFieldVariationalDistribution` variational distribution class.\n",
    "3. `inducing_points` should be set to `train_x`. This results in fast optimization and accurate predictions. \n",
    "4. There are two hyperparameters that you need to specify:\n",
    "  - `k`: number of nearest neighbors used. The higher the `k` is, the better the approximation accuracy is, but also more computations are needed. Default value is 256. \n",
    "  - `training_batch_size`: the mini-batch size of inducing points used in stochastic optimization. Default value is 256. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import ApproximateGP\n",
    "from gpytorch.variational.nearest_neighbor_variational_strategy import NNVariationalStrategy\n",
    "\n",
    "\n",
    "class GPModel(ApproximateGP):\n",
    "    def __init__(self, inducing_points, likelihood, k=256, training_batch_size=256):\n",
    "\n",
    "        m, d = inducing_points.shape\n",
    "        self.m = m\n",
    "        self.k = k\n",
    "\n",
    "        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(m)\n",
    "\n",
    "        if torch.cuda.is_available():\n",
    "            inducing_points = inducing_points.cuda()\n",
    "\n",
    "        variational_strategy = NNVariationalStrategy(self, inducing_points, variational_distribution, k=k,\n",
    "                                                     training_batch_size=training_batch_size)\n",
    "        super(GPModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ZeroMean()\n",
    "        self.covar_module = gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=d)\n",
    "        \n",
    "        self.likelihood = likelihood\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "    def __call__(self, x, prior=False, **kwargs):\n",
    "        if x is not None:\n",
    "            if x.dim() == 1:\n",
    "                x = x.unsqueeze(-1)\n",
    "        return self.variational_strategy(x=x, prior=False, **kwargs)\n",
    "    \n",
    "if smoke_test:\n",
    "    k = 32\n",
    "    training_batch_size = 32\n",
    "else:\n",
    "    k = 256\n",
    "    training_batch_size = 64\n",
    "\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "# Note: one should use full training set as inducing points!\n",
    "model = GPModel(inducing_points=train_x, likelihood=likelihood, k=k, training_batch_size=training_batch_size)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    likelihood = likelihood.cuda()\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Model (mode 1, recommended)\n",
    "\n",
    "The cell below trains the model above, learning both the hyperparameters of the Gaussian process and the parameters of the neural network in an end-to-end fashion using Type-II MLE. VNNGP's training objective is to miminize the negative ELBO. Note that in the beginning, we introduce that VNNGP is ameanable to stochastic optimization over both data points and inducing points. That means in each iteration, we will sample a mini-batch of data points and a mini-batch of inducing points, compute the stochastic ELBO estimate, and then take a gradient step to update the model parameters. \n",
    "\n",
    "There are **two training modes available**. In this section we will introduce the mode 1, which is what we recommended in practice and what is implemented in experiments of the original paper. Since VNNGP sets inducing point locations to observed input locations, `inducing points` are essentially the `train_x`. Therefore, there is no need to separately iterate over training data and inducing points. As a result, we could just sample a mini-batch of inducing points, which would be treated as a mini-batch of training data as well. In this case, the mini-batch size for training data is the same as that for inducing points, which is `training_batch_size` we set above. \n",
    "\n",
    "While we recommend this training mode as it yields faster training, we do provide another training mode that allows users to use different mini-batches of training data and inducing points. See the last part of the notebook. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fda2d925f7e9437a96d170caae664494",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=20.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 1 if smoke_test else 20\n",
    "num_batches = model.variational_strategy._total_training_batches\n",
    "\n",
    "\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "# Our loss object. We're using the VariationalELBO\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))\n",
    "\n",
    "\n",
    "epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc=\"Epoch\")\n",
    "for epoch in epochs_iter:\n",
    "    minibatch_iter = tqdm.notebook.tqdm(range(num_batches), desc=\"Minibatch\", leave=False)\n",
    "    \n",
    "    for i in minibatch_iter:\n",
    "        optimizer.zero_grad()\n",
    "        output = model(x=None)\n",
    "        # Obtain the indices for mini-batch data\n",
    "        current_training_indices = model.variational_strategy.current_training_indices\n",
    "        # Obtain the y_batch using indices. It is important to keep the same order of train_x and train_y\n",
    "        y_batch = train_y[...,current_training_indices]\n",
    "        if torch.cuda.is_available():\n",
    "            y_batch = y_batch.cuda()\n",
    "        loss = -mll(output, y_batch)\n",
    "        minibatch_iter.set_postfix(loss=loss.item())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "\n",
    "test_dataset = TensorDataset(test_x, test_y)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "means = torch.tensor([0.])\n",
    "test_mse = 0\n",
    "with torch.no_grad():\n",
    "    for x_batch, y_batch in test_loader:\n",
    "        preds = model(x_batch)\n",
    "        means = torch.cat([means, preds.mean.cpu()])\n",
    "         \n",
    "        diff = torch.pow(preds.mean - y_batch, 2)\n",
    "        diff = diff.sum(dim=-1) / test_x.size(0) # sum over bsz and scaling\n",
    "        diff = diff.mean() # average over likelihood_nsamples\n",
    "        test_mse += diff\n",
    "means = means[1:]\n",
    "test_rmse = test_mse.sqrt().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8139828443527222\n"
     ]
    }
   ],
   "source": [
    "print(test_rmse)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Model (mode 2)\n",
    "\n",
    "In this mode, users are able to sample separate mini-batches for training data and inducing points. Note that this will yield a slower training speed, since every iteration requires finding the set of inducing points that matches the current batch of training data for calculating ELBO. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# instantiate the model\n",
    "if smoke_test:\n",
    "    k = 32\n",
    "    training_batch_size = 32\n",
    "else:\n",
    "    k = 256\n",
    "    training_batch_size = 256\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "model = GPModel(inducing_points=train_x, likelihood=likelihood, k=k, training_batch_size=training_batch_size)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    likelihood = likelihood.cuda()\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26688e5d3bc549718c4e510c56250bee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=20.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Minibatch'), FloatProgress(value=0.0, max=7.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# prepare for dataset\n",
    "train_dataset = TensorDataset(train_x, train_y)\n",
    "# this batch-size does not need to match the training-batch-size specified above\n",
    "train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
    "\n",
    "num_epochs = 1 if smoke_test else 20\n",
    "\n",
    "\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "# Our loss object. We're using the VariationalELBO\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))\n",
    "\n",
    "\n",
    "epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc=\"Epoch\")\n",
    "for i in epochs_iter:\n",
    "    # Within each iteration, we will go over each minibatch of data\n",
    "    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc=\"Minibatch\", leave=False)\n",
    "    for x_batch, y_batch in minibatch_iter:\n",
    "        optimizer.zero_grad()\n",
    "        output = model(x_batch)\n",
    "        loss = -mll(output, y_batch)\n",
    "        minibatch_iter.set_postfix(loss=loss.item())\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "means = torch.tensor([0.])\n",
    "test_mse = 0\n",
    "with torch.no_grad():\n",
    "    for x_batch, y_batch in test_loader:\n",
    "        preds = model(x_batch)\n",
    "        means = torch.cat([means, preds.mean.cpu()])\n",
    "         \n",
    "        diff = torch.pow(preds.mean - y_batch, 2)\n",
    "        diff = diff.sum(dim=-1) / test_x.size(0) # sum over bsz and scaling\n",
    "        diff = diff.mean() # average over likelihood_nsamples\n",
    "        test_mse += diff\n",
    "means = means[1:]\n",
    "test_rmse = test_mse.sqrt().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8090996742248535\n"
     ]
    }
   ],
   "source": [
    "print(test_rmse)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
